import matplotlib.pyplot as plt
import seaborn as sns
from model import datatable
from regressions import random_forest, knn_regression, decision_tree, logistic_regression, svm_classification, \
    nbayes_classification, ann_network
from sklearn import metrics
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.preprocessing import MinMaxScaler
import pandas as pd


def plotter(test, train):
    """Clean the data, then fit, plot and score every model in turn.

    The function:
      1. obtains the two prepared frames from ``cleaning(test, train)``;
      2. shows an annotated correlation heatmap of the first frame;
      3. for each model, calls it as ``model(df, df_test)`` to get
         ``(y_test, y_pred)``, shows its confusion matrix for labels
         ``[0, 1]`` and records its scores through ``evaluator``;
      4. shows one bar chart of accuracies and one of F1 scores.

    :param test: forwarded unchanged to ``cleaning`` as its first argument.
    :param train: forwarded unchanged to ``cleaning`` as its second argument.
    :return: None; all results are printed or displayed.
    """
    # Order matters: final_evaluation() assigns bar labels by position and
    # expects the scores in exactly this sequence.
    models = (
        ('Random Forest', random_forest),
        ('KNN', knn_regression),
        ('Decision Tree', decision_tree),
        ('Logistic Regression', logistic_regression),
        ('SVM', svm_classification),
        ('Naive Bayes', nbayes_classification),
        ('ANN', ann_network),
    )
    evaluations = []    # accuracy per model, in `models` order
    evaluations_r = []  # F1 per model, in `models` order
    df, df_test = cleaning(test, train)

    # Correlation overview of the first frame returned by cleaning().
    _, ax = plt.subplots(figsize=(40, 40))
    sns.heatmap(df.corr(), annot=True, cmap='Blues', ax=ax)
    plt.show()

    for name, run_model in models:
        y_test, y_pred = run_model(df, df_test)
        cm = confusion_matrix(y_test, y_pred, labels=[0, 1])
        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[0, 1])
        plot(name, disp)
        evaluations, evaluations_r = evaluator(y_test, y_pred, evaluations, evaluations_r, name)

    final_evaluation(evaluations, "Accuracy", "Green")
    final_evaluation(evaluations_r, "F1", "Purple")


def plot(z, k):
    """Draw the display object ``k`` and show it with ``z`` as its title.

    :param z: title text set on the display's axes.
    :param k: object exposing ``plot()`` and an ``ax_`` attribute; ``plotter``
        passes a ``ConfusionMatrixDisplay``.
    :return: None; the figure is shown through ``plt.show()``.
    """
    display, title = k, z
    display.plot()
    # ax_ is read only after plot() has run, as in the original call order.
    axes = display.ax_
    axes.set_title(title)
    plt.show()


def cleaning(test, train):
    """Load both tables via ``datatable`` and prepare them for modelling.

    Steps:
      * drop every column of ``df`` that holds a single unique value;
      * min-max scale the columns listed in ``cols_to_norm`` in both frames,
        using one scaler fitted on ``df`` only;
      * drop rows without a ``tool_condition`` and replace remaining missing
        values with the empty string;
      * print both frames and plot a pie chart of the ``tool_condition``
        counts of ``df`` (no ``plt.show()`` is issued here).

    :param test: forwarded unchanged to ``datatable`` as its first argument.
    :param train: forwarded unchanged to ``datatable`` as its second argument.
    :return: tuple ``(df, df_test)`` of the prepared frames.
    """
    df, df_test = datatable(test, train)

    # One pass is sufficient: dropping columns does not change the rows, so no
    # remaining column can become constant afterwards.
    constant_cols = [col for col in df.columns if len(df[col].unique()) == 1]
    df.drop(constant_cols, axis=1, inplace=True)

    cols_to_norm = ['X1_ActualPosition', 'X1_ActualVelocity', 'X1_ActualAcceleration', 'X1_CommandPosition',
                    'X1_CommandVelocity', 'X1_CommandAcceleration', 'X1_CurrentFeedback', 'X1_DCBusVoltage',
                    'X1_OutputCurrent', 'X1_OutputVoltage', 'X1_OutputPower', 'Y1_ActualPosition', 'Y1_ActualVelocity',
                    'Y1_ActualAcceleration', 'Y1_CommandPosition', 'Y1_CommandVelocity', 'Y1_CommandAcceleration',
                    'Y1_CurrentFeedback', 'Y1_DCBusVoltage', 'Y1_OutputCurrent', 'Y1_OutputVoltage', 'Y1_OutputPower',
                    'Z1_ActualPosition', 'Z1_ActualVelocity', 'Z1_ActualAcceleration', 'Z1_CommandPosition',
                    'Z1_CommandVelocity', 'Z1_CommandAcceleration',
                    'S1_ActualPosition', 'S1_ActualVelocity',
                    'S1_ActualAcceleration', 'S1_CommandPosition', 'S1_CommandVelocity', 'S1_CommandAcceleration',
                    'S1_CurrentFeedback', 'S1_DCBusVoltage', 'S1_OutputCurrent', 'S1_OutputVoltage', 'S1_OutputPower',
                    'M1_sequence_number', 'M1_CURRENT_FEEDRATE',
                    'Machining_Process', 'feedrate', 'clamp_pressure',
                    'machining_finalized', 'passed_visual_inspection']

    # Fit the scaler on df only and reuse it for df_test, so both frames are
    # mapped with the same per-column min/max. Fitting a second scaler on
    # df_test would put the two frames on different scales.
    # NOTE(review): assumes df is the training frame and df_test the held-out
    # frame, as the names suggest - confirm against datatable().
    scaler = MinMaxScaler()
    df[cols_to_norm] = scaler.fit_transform(df[cols_to_norm])
    df.dropna(axis=0, how="any", subset=['tool_condition'], inplace=True)
    df = df.fillna("", inplace=False)

    df_test[cols_to_norm] = scaler.transform(df_test[cols_to_norm])
    df_test.dropna(axis=0, how="any", subset=['tool_condition'], inplace=True)
    df_test = df_test.fillna("", inplace=False)

    print(df)
    labels = ['not worn', 'worn']
    # value_counts() orders by frequency, so positional labels could end up on
    # the wrong slice; sort by class value to keep the mapping fixed.
    # NOTE(review): assumes the smaller class value (0) means "not worn" and
    # the larger (1) means "worn" - confirm against the encoding in datatable().
    df['tool_condition'].value_counts().sort_index().plot(kind='pie', labels=labels)
    print(df_test)
    return df, df_test


def evaluator(y_test, y_pred, evaluations, evaluations_r, method):
    """Score one model's predictions, print the scores and record two of them.

    Precision, accuracy, recall and F1 are computed from ``y_test`` and
    ``y_pred`` and printed as a one-column table headed by ``method``.

    :param y_test: true labels.
    :param y_pred: predicted labels.
    :param evaluations: list that receives the accuracy (appended in place).
    :param evaluations_r: list that receives the F1 score (appended in place).
    :param method: column heading used in the printed table.
    :return: tuple ``(evaluations, evaluations_r)`` - the same two lists.
    """
    scorers = (
        ('Precision', metrics.precision_score),
        ('Accuracy', metrics.accuracy_score),
        ('Recall', metrics.recall_score),
        ('F1', metrics.f1_score),
    )
    # Evaluated in table order: precision, accuracy, recall, F1.
    scores = {label: scorer(y_test, y_pred) for label, scorer in scorers}

    report = pd.DataFrame(list(scores.values()), index=list(scores.keys()), columns=[method])
    print(report)

    evaluations.append(scores['Accuracy'])
    evaluations_r.append(scores['F1'])
    return evaluations, evaluations_r


def final_evaluation(evaluations, method, color):
    """Show a bar chart comparing one score across the seven methods.

    :param evaluations: indexable collection whose positions 0-6 hold the
        scores for Random Forest, KNN, Decision Tree, Logistic Regression,
        SVM, Naive Bayes and ANN, in that order.
    :param method: text used as the y-axis label.
    :param color: bar colour passed to ``plt.bar``.
    :return: None; the chart is shown through ``plt.show()``.
    :raises IndexError: if a list with fewer than seven scores is given.
    """
    method_names = ('Random Forest', 'KNN', 'Decision Tree', 'Logistic Regression', 'SVM',
                    'Naive Bayes', 'ANN')
    # Explicit indexing (rather than zip) so a short input still fails loudly
    # before any figure is created.
    scores = [evaluations[position] for position in range(len(method_names))]

    plt.figure(figsize=(10, 5))
    plt.bar(list(method_names), scores, color=color, width=0.4)
    plt.xlabel("Methods")
    plt.ylabel(method)
    plt.title("Evaluation of the methods selected")
    plt.show()
